{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../../../FinNLP/')  # https://github.com/AI4Finance-Foundation/FinNLP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2023-08-04 17:08:33,847] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizerFast   # 4.30.2\n",
    "from peft import PeftModel  # 0.4.0\n",
    "import torch\n",
    "\n",
    "from finnlp.benchmarks.fpb import test_fpb\n",
    "from finnlp.benchmarks.fiqa import test_fiqa , add_instructions\n",
    "from finnlp.benchmarks.tfns import test_tfns\n",
    "from finnlp.benchmarks.nwgi import test_nwgi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Model (run only one of the following cells, depending on which model you want to evaluate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 1. **FinGPT v3.1** based on ChatGLM2, runnable on 1 * RTX 3090"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# v3.1: ChatGLM2-6B base weights plus the FinGPT sentiment LoRA adapter.\n",
    "base_model = \"THUDM/chatglm2-6b\"\n",
    "peft_model = \"oliverwang15/FinGPT_v31_ChatGLM2_Sentiment_Instruction_LoRA_FT\"\n",
    "\n",
    "# trust_remote_code=True allows the Auto classes to run the code hosted in the base_model repo.\n",
    "tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)\n",
    "\n",
    "# Load the base weights in 8-bit and let device_map=\"auto\" decide the placement.\n",
    "model = AutoModel.from_pretrained(\n",
    "    base_model,\n",
    "    trust_remote_code=True,\n",
    "    load_in_8bit=True,\n",
    "    device_map=\"auto\",\n",
    ")\n",
    "\n",
    "# Wrap the base model with the LoRA adapter, then switch to inference mode.\n",
    "model = PeftModel.from_pretrained(model, peft_model).eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2. **FinGPT v3.2** based on Llama2, runnable on 1 * A100; it also runs on 1 * RTX 3090 if `load_in_8bit = True` is passed to `LlamaForCausalLM.from_pretrained`, but inference is slower"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "21d8b449e1734815b004b8532621884a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# v3.2: Llama-2-7b-chat base weights plus the FinGPT sentiment LoRA adapter.\n",
    "# base_model = \"meta-llama/Llama-2-7b-chat-hf\"  # Access needed\n",
    "base_model = \"daryl149/llama-2-7b-chat-hf\"\n",
    "peft_model = \"oliverwang15/FinGPT_v32_Llama2_Sentiment_Instruction_LoRA_FT\"\n",
    "\n",
    "# `trust_remote_code` only has an effect on the Auto* classes (transformers warns that it is\n",
    "# ignored when passed to a concrete class), so it is not passed to the Llama classes here.\n",
    "tokenizer = LlamaTokenizerFast.from_pretrained(base_model)\n",
    "# Use the EOS token as the padding token.\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "# To run on a single RTX 3090, add `load_in_8bit=True` to this call (slower inference).\n",
    "model = LlamaForCausalLM.from_pretrained(base_model, device_map=\"cuda:0\")\n",
    "# Wrap the base model with the LoRA adapter.\n",
    "model = PeftModel.from_pretrained(model, peft_model)\n",
    "model = torch.compile(model)  # Please comment this line if your platform does not support torch.compile\n",
    "# Switch to inference mode.\n",
    "model = model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batch size handed to every benchmark call below (test_fpb, test_fiqa, test_tfns, test_nwgi).\n",
    "batch_size = 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset financial_phrasebank (/xfs/home/tensor_zy/.cache/huggingface/datasets/financial_phrasebank/sentences_50agree/1.0.0/550bde12e6c30e2674da973a55f57edde5181d53f5a5a34c1531c53f93b7e141)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e51735f248a54fee9d8d06a62eaccfb9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached split indices for dataset at /xfs/home/tensor_zy/.cache/huggingface/datasets/financial_phrasebank/sentences_50agree/1.0.0/550bde12e6c30e2674da973a55f57edde5181d53f5a5a34c1531c53f93b7e141/cache-5054f1bbfc7f1863.arrow and /xfs/home/tensor_zy/.cache/huggingface/datasets/financial_phrasebank/sentences_50agree/1.0.0/550bde12e6c30e2674da973a55f57edde5181d53f5a5a34c1531c53f93b7e141/cache-c9d8592c85fcadbe.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Prompt example:\n",
      "Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: L&T has also made a commitment to redeem the remaining shares by the end of 2011 .\n",
      "Answer: \n",
      "\n",
      "\n",
      "Total len: 1212. Batchsize: 8. Total steps: 152\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 152/152 [02:00<00:00,  1.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Acc: 0.8506600660066007. F1 macro: 0.8396686677078228. F1 micro: 0.8506600660066006. F1 weighted (BloombergGPT): 0.8500004920027702. \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# FPB: the run log below loads the `financial_phrasebank` dataset (`sentences_50agree`).\n",
    "# The result gets its own name so that other benchmark cells cannot overwrite it.\n",
    "fpb_res = test_fpb(model, tokenizer, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset csv (/xfs/home/tensor_zy/.cache/huggingface/datasets/pauri32___csv/pauri32--fiqa-2018-0ac3134eda915d9d/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9931e3eebaf541e5af591300ce0a0c79",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached split indices for dataset at /xfs/home/tensor_zy/.cache/huggingface/datasets/pauri32___csv/pauri32--fiqa-2018-0ac3134eda915d9d/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d/cache-a47229243cae4784.arrow and /xfs/home/tensor_zy/.cache/huggingface/datasets/pauri32___csv/pauri32--fiqa-2018-0ac3134eda915d9d/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d/cache-09667165ea545971.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Prompt example:\n",
      "Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: This $BBBY stock options trade would have more than doubled your money https://t.co/Oa0loiRIJL via @TheStreet\n",
      "Answer: \n",
      "\n",
      "\n",
      "Total len: 275. Batchsize: 8. Total steps: 35\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 35/35 [00:23<00:00,  1.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Acc: 0.8436363636363636. F1 macro: 0.7524122647015905. F1 micro: 0.8436363636363636. F1 weighted (BloombergGPT): 0.860437084621201. \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# FiQA: prompts are built with `add_instructions`; the run log below loads `pauri32/fiqa-2018`.\n",
    "# The result gets its own name so that other benchmark cells cannot overwrite it.\n",
    "fiqa_res = test_fiqa(model, tokenizer, prompt_fun=add_instructions, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset csv (/xfs/home/tensor_zy/.cache/huggingface/datasets/zeroshot___csv/zeroshot--twitter-financial-news-sentiment-467590c9f24f65ad/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8d351d01c91c4df2b6736cebd2bfc96b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Prompt example:\n",
      "Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $ALLY - Ally Financial pulls outlook https://t.co/G9Zdi1boy5\n",
      "Answer: \n",
      "\n",
      "\n",
      "Total len: 2388. Batchsize: 8. Total steps: 299\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 299/299 [03:32<00:00,  1.41it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Acc: 0.8936348408710217. F1 macro: 0.866430480618281. F1 micro: 0.8936348408710217. F1 weighted (BloombergGPT): 0.8937510370439953. \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# TFNS: the run log below loads `zeroshot/twitter-financial-news-sentiment`.\n",
    "# The result gets its own name so that other benchmark cells cannot overwrite it.\n",
    "tfns_res = test_tfns(model, tokenizer, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset parquet (/xfs/home/tensor_zy/.cache/huggingface/datasets/oliverwang15___parquet/oliverwang15--news_with_gpt_instructions-ec641c48430028ee/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d2aadd130768450989f9c73a1a00292f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Prompt example:\n",
      "Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: In the latest trading session, Adobe Systems (ADBE) closed at $535.98, marking a +0.31% move from the previous day.\n",
      "Answer: \n",
      "\n",
      "\n",
      "Total len: 4047. Batchsize: 8. Total steps: 506\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 506/506 [07:55<00:00,  1.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Acc: 0.6360266864343959. F1 macro: 0.6443667929544722. F1 micro: 0.6360266864343959. F1 weighted (BloombergGPT): 0.6324602746076219. \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# NWGI: the run log below loads `oliverwang15/news_with_gpt_instructions`.\n",
    "# The result gets its own name so that other benchmark cells cannot overwrite it.\n",
    "nwgi_res = test_nwgi(model, tokenizer, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fingpt",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
